Conv2dTranspose
计算二维转置卷积,可以视为 Conv2d 对输入求梯度,也称为反卷积(实际不是真正的反卷积)。
输入的 shape 通常为 \((N, H_{in}, W_{in}, C_{in})\),其中:
\(N\) 是 batch size
\(C_{in}\) 是输入通道数
\(H_{in}, W_{in}\) 分别为特征层的高度和宽度
- 输入:
input_x - 输入数据的地址
input_w - 输入卷积核权重的地址
bias - 输入偏置的地址
param - 算子计算所需参数的结构体。其各成员见下述。
core_mask - 核掩码。
ConvTransposeParameter定义:
// Parameter block for the 2-D transposed-convolution (deconvolution) kernels.
// Shapes follow the NHWC layout used throughout this document.
typedef struct ConvTransposeParameter {
  void *workspace_;      // scratch memory for intermediate results
  int output_batch_;     // total number of output batches (N of output)
  int input_batch_;      // total number of input batches (N of input)
  int input_h_;          // input height
  int input_w_;          // input width
  int output_h_;         // output height
  int output_w_;         // output width
  int input_channel_;    // number of input channels
  int output_channel_;   // number of output channels
  int kernel_h_;         // kernel height
  int kernel_w_;         // kernel width
  int group_;            // number of groups
  int pad_l_;            // left padding
  int pad_u_;            // top padding
  int dilation_h_;       // kernel dilation along h
  int dilation_w_;       // kernel dilation along w
  int stride_h_;         // stride along h
  int stride_w_;         // stride along w
  int buffer_size_;      // tile buffer size allocated for blocked computation
} ConvTransposeParameter;
- 输出:
out_y - 输出地址。
- 支持平台:
FT78NE、MT7004
备注
FT78NE 支持int8, fp32
MT7004 支持fp16, fp32
共享存储版本:
-
void i8_convtranspose_s(int8_t *input_x, int8_t *input_w, int8_t *out_y, int *bias, ConvTransposeParameter *conv_param, int core_mask)
-
void hp_convtranspose_s(half *input_x, half *input_w, half *out_y, half *bias, ConvTransposeParameter *conv_param, int core_mask)
-
void fp_convtranspose_s(float *input_x, float *input_w, float *out_y, float *bias, ConvTransposeParameter *conv_param, int core_mask)
C调用示例:
1void TestConvTransposeSMCFp32(int* input_shape, int* weight_shape, int* output_shape, int* stride, int* padding, int* dilation, int groups, float* bias, int core_mask) {
2 int core_id = get_core_id();
3 int logic_core_id = GetLogicCoreId(core_mask, core_id);
4 int core_num = GetCoreNum(core_mask);
5 float* input_data = (float*)0x88000000;
6 float* weight = (float*)0x89000000;
7 float* output_data = (float*)0x90000000;
8 float* bias_data = (float*)0x91000000;
9 float* check = (float*)0x94000000;
10 ConvTransposeParameter* param = (ConvTransposeParameter*)0x92000000;
11 if (logic_core_id == 0) {
12 memcpy(bias_data, bias, sizeof(float) * output_shape[3]);
13 memset(output_data, 0, output_shape[0] * output_shape[1] * output_shape[2] * output_shape[3] * sizeof(float));
14 memset(check, 0, output_shape[0] * output_shape[1] * output_shape[2] * output_shape[3] * sizeof(float));
15 param->dilation_h_ = dilation[0];
16 param->dilation_w_ = dilation[1];
17 param->group_ = groups;
18 param->input_batch_ = input_shape[0];
19 param->input_h_ = input_shape[1];
20 param->input_w_ = input_shape[2];
21 param->input_channel_ = input_shape[3];
22 param->kernel_h_ = weight_shape[1];
23 param->kernel_w_ = weight_shape[2];
24 param->output_batch_ = output_shape[0];
25 param->output_h_ = output_shape[1];
26 param->output_w_ = output_shape[2];
27 param->output_channel_ = output_shape[3];
28 param->stride_h_ = stride[0];
29 param->stride_w_ = stride[0];
30 param->pad_u_ = padding[0];
31 param->pad_l_ = padding[2];
32 param->workspace_ = (float*)0xA0000000;
33 }
34 sys_bar(0, core_num); // 初始化参数完成后进行同步
35 fp_convtranspose_s(input_data, weight, output_data, bias_data, param, core_mask);
36}
37
// Driver for the shared-memory example: depthwise-style transposed conv
// (groups == channels), NHWC tensors, four cores enabled in core_mask.
void main(){
  int in_channel = 6;
  int out_channel = 6;
  int groups = 6;
  int input_shape[4] = {2, 5, 7, in_channel};                 // NHWC
  int weight_shape[4] = {in_channel, 3, 3, out_channel / groups};
  int output_shape[4] = {2, 7, 9, out_channel};               // NHWC
  int stride[2] = {1, 1};
  int padding[4] = {0, 0, 0, 0};
  int dilation[2] = {1, 1};
  float bias[] = {0, 0, 0, 0, 0, 0};                          // one bias per output channel
  int core_mask = 0b1111;                                     // launch on four cores
  TestConvTransposeSMCFp32(input_shape, weight_shape, output_shape, stride, padding, dilation, groups, bias, core_mask);
}
私有存储版本:
-
void i8_convtranspose_p(int8_t *input_x, int8_t *input_w, int8_t *out_y, int *bias, ConvTransposeParameter *conv_param, int core_mask)
-
void hp_convtranspose_p(half *input_x, half *input_w, half *out_y, half *bias, ConvTransposeParameter *conv_param, int core_mask)
-
void fp_convtranspose_p(float *input_x, float *input_w, float *out_y, float *bias, ConvTransposeParameter *conv_param, int core_mask)
C调用示例:
1void TestConvTransposeL2Fp32(int* input_shape, int* weight_shape, int* output_shape, int* stride, int* padding, int* dilation, int groups, float* bias, int core_mask) {
2 float* input_data = (float*)0x10000000; // 私有存储版本地址设置在AM内
3 float* weight = (float*)0x10001000;
4 float* output_data = (float*)0x10002000;
5 float* bias_data = (float*)0x10003000;
6 float* check = (float*)0x10004000;
7 ConvTransposeParameter* param = (ConvTransposeParameter*)0x10005000;
8 memcpy(bias_data, bias, sizeof(float) * output_shape[3]);
9 memset(output_data, 0, output_shape[0] * output_shape[1] * output_shape[2] * output_shape[3] * sizeof(float));
10 memset(check, 0, output_shape[0] * output_shape[1] * output_shape[2] * output_shape[3] * sizeof(float));
11 param->dilation_h_ = dilation[0];
12 param->dilation_w_ = dilation[1];
13 param->group_ = groups;
14 param->input_batch_ = input_shape[0];
15 param->input_h_ = input_shape[1];
16 param->input_w_ = input_shape[2];
17 param->input_channel_ = input_shape[3];
18 param->kernel_h_ = weight_shape[1];
19 param->kernel_w_ = weight_shape[2];
20 param->output_batch_ = output_shape[0];
21 param->output_h_ = output_shape[1];
22 param->output_w_ = output_shape[2];
23 param->output_channel_ = output_shape[3];
24 param->stride_h_ = stride[0];
25 param->stride_w_ = stride[0];
26 param->pad_u_ = padding[0];
27 param->pad_l_ = padding[2];
28 param->workspace_ = (float*)0x10006000;
29 param->buffer_size_ = 1024; // 私有存储版本中,必须设置该参数,用于确定分块计算的大小
30 fp_convtranspose_p(input_data, weight, output_data, bias_data, param, core_mask);
31}
32
// Driver for the private-memory example: same shapes as the shared-memory
// case, but only a single core may be enabled in core_mask.
void main(){
  int in_channel = 6;
  int out_channel = 6;
  int groups = 6;
  int input_shape[4] = {2, 5, 7, in_channel};                 // NHWC
  int weight_shape[4] = {in_channel, 3, 3, out_channel / groups};
  int output_shape[4] = {2, 7, 9, out_channel};               // NHWC
  int stride[2] = {1, 1};
  int padding[4] = {0, 0, 0, 0};
  int dilation[2] = {1, 1};
  float bias[] = {0, 0, 0, 0, 0, 0};                          // one bias per output channel
  int core_mask = 0b0001;                                     // private-memory version: single core only
  TestConvTransposeL2Fp32(input_shape, weight_shape, output_shape, stride, padding, dilation, groups, bias, core_mask);
}